import json
import math
import argparse
from tqdm import tqdm
import sys
import time
import torch.multiprocessing as mp
sys.path.append('./')
from utils.logging_utils import setup_logger_to_stdout
from utils.initial_agent import init_agent

# Module-level logger used by every function below. Because test_loop switches
# multiprocessing to the 'spawn' start method, worker processes re-import this
# module, so this line also runs once in each worker.
logger = setup_logger_to_stdout()


def test_loop(args):
    mp.set_start_method('spawn', force=True)
  
    with open(args.dataset_path, 'rb') as file:
        data = json.load(file)
    # data = data[-10:]
    manager = mp.Manager()
    predResults = manager.list([])
    progressCounter = manager.Value('i', 0)
    listLock = manager.Lock()
    progressLock = manager.Lock()
    status_queue = manager.Queue()

    chunkSize = math.ceil(len(data) / args.num_process)
    chunks = [data[i:i + chunkSize] for i in range(0, len(data), chunkSize)]
    
    try:
        visible_gpus = eval(args.deviceIds)
        agentDeviceCount = math.ceil(len(visible_gpus) / args.num_process)
        deviceIDs = [visible_gpus[i * agentDeviceCount: (i + 1) * agentDeviceCount] for i in range(args.num_process)]
    except:
        logger.warning(f"deviceIDs: {deviceIDs}, type: {type(deviceIDs)}")
      
    loadManager = mp.Manager()
    model_loaded_events = [loadManager.Event() for _ in range(args.num_process)]

    processes = []
    for rank, chunk in enumerate(chunks):
        chunk_start = rank * chunkSize
        p = mp.Process(
            target=test_process, 
            args=(args, rank, deviceIDs[rank], chunk, chunk_start, predResults, listLock, progressCounter, progressLock, model_loaded_events[rank], status_queue),
            daemon=False
        )
        p.start()
        processes.append(p)
    logger.info("Waiting for all agents to finish loading models...")
    for event in model_loaded_events:
        event.wait()
    
    logger.info("All models loaded. Start annotation progress bar.")

    with tqdm(total=len(data), desc=f"predict {args.dataset_path}") as pbar:
        last_progress = 0
        while any(p.is_alive() for p in processes):
            with progressLock:
                current_progress = progressCounter.value

            if current_progress > last_progress:
                pbar.update(current_progress - last_progress)
                last_progress = current_progress
            time.sleep(0.5)
        for p in processes:
            p.join(timeout=1080000)
            if p.is_alive():
                # print(f"Process {p.pid} timed out. Terminating.")
                logger.warning(f"Process {p.pid} timed out. Terminating.")
                p.terminate()
                p.join()
        
        statuses = []
        while not status_queue.empty():
            statuses.append(status_queue.get())
        for sid, status, info in statuses:
            logger.info(f"[Agent {sid}] Status: {status}, Info: {info}")
            
        failed_processes = [s for s in statuses if s[1] != "success"]
        if failed_processes:
            logger.warning(f"\n {len(failed_processes)} agents failed. You may need to retry or debug.")
        
    allPredResults = list(predResults)
    
    return allPredResults

def _predict_one(agent, obs, args):
    """Build the result record for a single observation ``obs``.

    Ground-truth parsing and model prediction are guarded independently, so a
    failure in one still yields a record. Failures are recorded in the
    record's ``"error"`` field; if both fail, both messages are kept.

    Raises:
        Whatever ``obs.get('image_size')[0]`` raises if ``image_size`` is
        missing or not indexable. This is not caught here, so it aborts the
        caller's chunk.
    """
    sample_result = {
        "image_path": obs.get("images"),
        "episode_id": obs.get("episode_id"),
        "step_id": obs.get('step_id'),
        "goal": obs.get('goal'),
        "predicted_action": 0,
        "real_action": obs.get("label"),
        "action_type": 0,
        "predicted_action_type": 0,
        "predicted_thought": "",
        "dataset_name": obs.get("dataset_name"),
        "real_thought": "",
        "is_success": False,
        "is_type_match": False,
        "bbox": obs.get("bbox", "")
    }

    image_size = obs.get('image_size')[0]
    sample_result['image_size'] = image_size

    try:
        ground_truth = agent.res_pre_process.extract_action(obs["label"])
        gt_action_type = agent.res_pre_process.get_action_type(ground_truth)
        sample_result["real_action"] = ground_truth
        sample_result["action_type"] = gt_action_type
        sample_result["real_thought"] = agent.res_pre_process.extract_thought(obs.get("label"))
    except Exception as e_gt:
        logger.warning(f"[GT Error] episode_id={obs.get('episode_id')}, step_id={obs.get('step_id')}: {e_gt}")
        logger.error(f"[GT Exception]: {e_gt}")
        sample_result["error"] = sample_result.get("error", "") + f" | ground_truth: {e_gt}"

    try:
        preds_action_raw = agent.get_action(obs, args)
        preds_action = agent.res_pre_process.extract_action(preds_action_raw)
        check_action = agent.res_pre_process.get_action_type(preds_action)
        sample_result["predicted_action"] = preds_action
        sample_result["predicted_action_type"] = check_action
        sample_result['predicted_thought'] = agent.res_pre_process.extract_thought(preds_action_raw)
    except Exception as e_pred:
        logger.warning(f"[Prediction Error] episode_id={obs.get('episode_id')}, step_id={obs.get('step_id')}: {e_pred}")
        # Append instead of overwriting, so a ground-truth error recorded above survives.
        if "error" in sample_result:
            sample_result["error"] += f" | prediction: {e_pred}"
        else:
            sample_result["error"] = f"prediction: {e_pred}"

    return sample_result


def test_process(args, rank, deviceIDs, chunk, chunk_start, predResults, listLock, progressCounter, progressLock, model_loaded_event, status_queue):
    """Worker entry point: load one agent on ``deviceIDs`` and predict every obs in ``chunk``.

    Args:
        args: CLI namespace; forwarded to ``init_agent`` and ``agent.get_action``.
        rank: Worker index, used in logs and status messages.
        deviceIDs: GPU ids exported as CUDA_VISIBLE_DEVICES for this process.
        chunk: The observations (dicts) this worker handles.
        chunk_start: Index of ``chunk[0]`` in the full dataset (currently unused here).
        predResults / listLock: Shared list that receives one result dict per
            sample, appended under ``listLock``.
        progressCounter / progressLock: Shared counter incremented under
            ``progressLock`` after each sample.
        model_loaded_event: Set once setup finishes, whether setup succeeded or
            failed, so the parent never waits on a dead worker.
        status_queue: Receives ``(rank, "success", n_samples)`` or
            ``(rank, "error", message)``.

    Returns:
        ``len(predResults)`` on success and ``None`` on failure. When run via
        ``mp.Process`` this return value is discarded.
    """
    import os
    visible_devices = ",".join(str(i) for i in deviceIDs)
    # Set before torch.cuda.init() below so the device restriction applies to this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices

    logger.info(f"Agent {rank} sees CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")

    try:
        import torch
        torch.cuda.empty_cache()
        torch.cuda.init()

        logger.info(f"Process {rank}: Using device: {visible_devices}")

        agent = init_agent(args, torch.device('cuda'), True, args.model_name)
    except Exception as e:
        logger.exception(f"[Agent {rank}] setup failed: {e}")
        status_queue.put((rank, "error", f"init: {e}"))
        return None
    finally:
        # Always release the parent, which blocks on this event before it starts
        # the progress bar. A worker failing during setup must not hang it.
        model_loaded_event.set()

    chunk_results = []
    try:
        for obs in chunk:
            sample_result = _predict_one(agent, obs, args)

            chunk_results.append(sample_result)
            with listLock:
                predResults.append(sample_result)

            with progressLock:
                progressCounter.value += 1

        status_queue.put((rank, "success", len(chunk_results)))
        return len(predResults)
    except Exception as e:
        logger.exception(f"[Agent {rank}] aborted after {len(chunk_results)} samples: {e}")
        status_queue.put((rank, "error", str(e)))
 
def parse_args():
    """Define the evaluation CLI and parse ``sys.argv``.

    Every option is a single-valued flag with a default, so the script runs
    with no arguments at all.

    Returns:
        argparse.Namespace: model_path, model_name, result_path, dataset_name,
        dataset_type, dataset_path, thought, num_process, deviceIds,
        probing_method and mask_object_ratio.
    """
    # (flag, type, default, help) -- registered in this order.
    option_table = (
        ('--model_path', str, "/data4/models/GUI-Owl-7B",
         'Path to the fine-tuned model. If not provided, the base model will be used.'),
        ('--model_name', str, "GUI-Owl-7B", 'model name'),
        ('--result_path', str, '/Agent_ScanKit/results/Visual/test.json',
         'Path to save the prediction results.'),
        ('--dataset_name', str, "AndroidControl", 'dataset name'),
        ('--dataset_type', str, 'low', 'dataset type'),
        ('--dataset_path', str, "/Agent_ScanKit/datasets/json/visual_mask/low/GUI-Owl-7B.json",
         'dataset path'),
        ('--thought', str, "false", 'w/o or w thought'),
        ('--num_process', int, 1, 'num process'),
        # Parsed later in test_loop.
        ('--deviceIds', str, "[1]", ''),
        ('--probing_method', str, "visual_edit", ''),
        ('--mask_object_ratio', float, 50, 'Ratio used for object masking during evaluation.'),
    )

    cli = argparse.ArgumentParser(description='Testing')
    for flag, value_type, default_value, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli.parse_args()

if __name__ == "__main__":
    # Parse the CLI and run the parallel prediction pass.
    cli_args = parse_args()
    predictions = test_loop(cli_args)

    # Build a second agent in this (parent) process purely to compute the
    # statistics. It is created with init_agent(..., False, ...); presumably
    # that skips loading model weights -- TODO confirm against init_agent.
    import torch
    stats_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    stats_agent = init_agent(cli_args, stats_device, False, cli_args.model_name)
    stats_agent.res_pre_process._res_statistics(cli_args, predictions)
    

